import copy
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch import optim
import torch.nn.functional as F
from load import get_P, get_Pd, get_W_lg
import random
from controlsnr import find_a_given_snr
from scipy.sparse import csr_matrix, save_npz

def simple_collate_fn(batch):
    """
    Collate a list of graph samples into batched tensors.

    Only valid when every graph in the batch has the same number of nodes,
    because the per-sample tensors are combined with ``torch.stack``.

    Args:
        batch: sequence of dicts.  ``sample['adj']`` must expose
            ``.toarray()`` (e.g. a scipy sparse matrix) and
            ``sample['labels']`` must be convertible by ``torch.tensor``.

    Returns:
        dict with
          - ``'adj'``:    float32 tensor of shape [B, N, N]
          - ``'labels'``: int64 tensor of shape [B, N]
    """
    def to_dense_float(matrix):
        return torch.tensor(matrix.toarray(), dtype=torch.float32)

    def to_long(values):
        return torch.tensor(values, dtype=torch.long)

    # Convert everything first, then stack (adjacencies before labels).
    adj_tensors = [to_dense_float(sample['adj']) for sample in batch]
    label_tensors = [to_long(sample['labels']) for sample in batch]

    return {
        'adj': torch.stack(adj_tensors),        # [B, N, N]
        'labels': torch.stack(label_tensors),   # [B, N]
    }

def midpoints(seq):
    """Return the midpoint of every pair of consecutive elements of *seq*.

    A sequence of length L yields L - 1 midpoints (an empty list for L < 2).
    """
    return [(left + right) / 2 for left, right in zip(seq, seq[1:])]


# ---- Dataset grids consumed by Generator.prepare_data() -------------------
# Each split is described by an SNR grid, a gamma grid, a C grid and a
# per-cell count; prepare_data() forwards them to create_dataset_grid(),
# which is defined outside this view.
# NOTE(review): presumably gamma is the Dirichlet concentration controlling
# class-size imbalance and C the degree constant (a + (k-1) b = C); confirm
# against create_dataset_grid.

# Training grid: 10 x 4 x 5 = 200 (snr, gamma, C) combinations.
snr_train   = [0.10, 0.50, 0.86, 1.12, 1.38, 1.64, 1.90, 2.16, 2.42, 2.94]  # 10
gamma_train = [0.30, 1.20, 3.00, 5.00]                                      # 4
C_train     = [5.0, 10.0, 15.0, 20.0, 25.0]                                 # 5
per_cell_tr = 1  # presumably graphs per grid cell -- TODO confirm

# Validation grid = midpoints of consecutive training values, i.e. values
# that interpolate between the training points: 9 x 3 x 4 = 108 combinations.
snr_val   = midpoints(snr_train)    # 9
gamma_val = midpoints(gamma_train)  # 3
C_val     = midpoints(C_train)      # 4
per_cell_v = 1


# Test grid: 5 x 4 x 1 = 20 (snr, gamma, C) combinations.  gamma 0.15 lies
# below the smallest training gamma (0.30).
snr_test = (0.60, 1.00, 1.60, 2.20, 2.80)
gamma_test = (0.15, 0.60, 1.50, 3.00)
C_test = (10.0,)
per_cell_te = 1


class Generator(object):
    def __init__(self, N_train=50, N_test=100, N_val=50,
                 generative_model='SBM_multiclass', p_SBM=0.8, q_SBM=0.2,
                 n_classes=2, path_dataset='dataset',
                 num_examples_train=100, num_examples_test=10,
                 num_examples_val=10):
        """
        Store the generator configuration.

        Args:
            N_train, N_test, N_val: number of nodes per graph for each split.
            generative_model: 'SBM' or 'SBM_multiclass' (the values accepted
                by the on-the-fly samplers).
            p_SBM, q_SBM: within- / between-community edge probabilities used
                by the fixed-parameter samplers.
            n_classes: number of communities.
            path_dataset: root directory for on-disk datasets.
            num_examples_train, num_examples_test, num_examples_val: number of
                graphs per split (used in dataset directory names and by
                create_dataset_random_otf).
        """
        # Graph sizes per split.
        self.N_train, self.N_test, self.N_val = N_train, N_test, N_val

        # Model definition.
        self.generative_model = generative_model
        self.p_SBM, self.q_SBM = p_SBM, q_SBM
        self.n_classes = n_classes
        self.path_dataset = path_dataset

        # Dataset handles; filled in later by prepare_data() or
        # create_dataset_random_otf().
        self.data_train = self.data_test = self.data_val = None

        # Number of graphs per split.
        self.num_examples_train = num_examples_train
        self.num_examples_test = num_examples_test
        self.num_examples_val = num_examples_val

        # Two-class size splits totalling 1000 nodes, from balanced (500/500)
        # to strongly imbalanced (50/950).
        self.fixed_class_sizes = [(small, 1000 - small)
                                  for small in (500, 400, 300, 200, 100, 50)]

    def SBM(self, p, q, N):
        W = np.zeros((N, N))

        p_prime = 1 - np.sqrt(1 - p)
        q_prime = 1 - np.sqrt(1 - q)

        n = N // 2

        W[:n, :n] = np.random.binomial(1, p, (n, n))
        W[n:, n:] = np.random.binomial(1, p, (N-n, N-n))
        W[:n, n:] = np.random.binomial(1, q, (n, N-n))
        W[n:, :n] = np.random.binomial(1, q, (N-n, n))
        W = W * (np.ones(N) - np.eye(N))
        W = np.maximum(W, W.transpose())

        perm = torch.randperm(N).numpy()
        blockA = perm < n
        labels = blockA * 2 - 1

        W_permed = W[perm]
        W_permed = W_permed[:, perm]
        return W_permed, labels


    def SBM_multiclass(self, p, q, N, n_classes):

        p_prime = 1 - np.sqrt(1 - p)
        q_prime = 1 - np.sqrt(1 - q)

        prob_mat = np.ones((N, N)) * q_prime

        n = N // n_classes  # 基础类别大小
        remainder = N % n_classes  # 不能整除的剩余部分
        n_last = n + remainder  # 最后一类的大小

        # 先对整除部分进行块状分配
        for i in range(n_classes - 1):  # 处理前 n_classes-1 类
            prob_mat[i * n: (i + 1) * n, i * n: (i + 1) * n] = p_prime

        # 处理最后一类
        start_idx = (n_classes - 1) * n  # 最后一类的起始索引
        prob_mat[start_idx: start_idx + n_last, start_idx: start_idx + n_last] = p_prime

        # 生成邻接矩阵
        W = np.random.rand(N, N) < prob_mat
        W = W.astype(int)

        W = W * (np.ones(N) - np.eye(N))  # 移除自环
        W = np.maximum(W, W.transpose())  # 确保无向图

        # 随机打乱节点顺序
        perm = torch.randperm(N).numpy()

        # 生成类别标签
        labels =np.minimum((perm // n) , n_classes - 1)

        W_permed = W[perm]
        W_permed = W_permed[:, perm]

        #计算P矩阵的特征向量
        prob_mat_permed = prob_mat[perm][:, perm]
        # np.fill_diagonal(prob_mat_permed, 0)  # 去除自环

        eigvals, eigvecs = np.linalg.eigh(prob_mat_permed)
        idx = np.argsort(eigvals)[::-1]
        eigvecs_top = eigvecs[:, idx[:n_classes]]

        return W_permed, labels, eigvecs_top  # 返回前n_classes特征向量

    def imbalanced_SBM_multiclass(self, p, q, N, n_classes, class_sizes):
        import numpy as np
        import torch

        # 上三角采样不会放大概率，直接用目标 p, q
        p_prime = float(p)
        q_prime = float(q)

        # 构造期望矩阵（块内 p，块间 q），无自环
        prob_mat = np.full((N, N), q_prime, dtype=float)
        boundaries = np.cumsum([0] + class_sizes)
        for i in range(n_classes):
            start, end = boundaries[i], boundaries[i + 1]
            prob_mat[start:end, start:end] = p_prime
        np.fill_diagonal(prob_mat, 0.0)

        # —— 关键修改：只采样上三角，然后镜像 —— #
        W = np.zeros((N, N), dtype=np.uint8)
        iu, ju = np.triu_indices(N, k=1)
        W[iu, ju] = (np.random.rand(iu.size) < prob_mat[iu, ju]).astype(np.uint8)
        W = (W + W.T).astype(np.uint8)  # 无向化；对角仍为 0

        # 打乱节点顺序
        perm = torch.randperm(N).numpy()

        # 生成并置乱标签
        labels = np.zeros(N, dtype=int)
        for i in range(n_classes):
            start, end = boundaries[i], boundaries[i + 1]
            labels[start:end] = i
        labels = labels[perm]

        # 同步置乱矩阵
        W_permed = W[perm][:, perm]

        # 置乱后的期望矩阵用于特征分解（与 W_permed 对齐）
        prob_mat_permed = prob_mat[perm][:, perm]
        eigvals, eigvecs = np.linalg.eigh(prob_mat_permed)
        idx = np.argsort(eigvals)[::-1]
        eigvecs_top = eigvecs[:, idx[:n_classes]]

        return W_permed, labels, eigvecs_top

    # def imbalanced_SBM_multiclass_hetero(p_in, q, N, n_classes, class_sizes, *,
    #                                      shuffle=True, return_eigs=True, rng=None,
    #                                      dtype=np.uint8):
    #     """
    #     生成不平衡多社区 SBM：
    #       - 每个社区 i 有自己的块内概率 p_in[i] = p_{ii}
    #       - 块间概率可为标量 q（所有 i!=j 相同），或对称矩阵 q[i,j]
    #
    #     参数
    #     ----
    #     p_in : array-like, shape (k,)
    #         各社区内部连接概率 [p_11, p_22, ..., p_kk]，必须在 (0,1) 内。
    #     q : float or array-like (k,k)
    #         若为标量，则所有 i!=j 的块间概率都为该值；
    #         若为矩阵，则使用 q[i,j]（要求对称且对角元素忽略）。
    #     N : int
    #         总节点数。
    #     n_classes : int
    #         社区数 k。
    #     class_sizes : list[int] or array-like, length k
    #         每个社区的大小，和应为 N。
    #     shuffle : bool
    #         是否随机置乱节点顺序后返回。
    #     return_eigs : bool
    #         是否返回置乱后期望矩阵的前 k 个特征向量（按特征值从大到小）。
    #     rng : int or np.random.Generator or None
    #         随机种子或生成器。
    #     dtype : numpy dtype
    #         邻接矩阵类型，默认 np.uint8。
    #
    #     返回
    #     ----
    #     W_out : (N,N) ndarray[dtype]
    #         采样得到的无向无自环邻接矩阵。
    #     labels_out : (N,) ndarray[int]
    #         置乱后的标签（0..k-1）。
    #     eigvecs_top : (N,k) ndarray[float] or None
    #         置乱后期望矩阵的前 k 个特征向量（若 return_eigs=False 则为 None）。
    #     """
    #     # ---- 检查与准备 ----
    #     p_in = np.asarray(p_in, dtype=float).reshape(-1)
    #     assert len(p_in) == n_classes, "p_in 长度必须等于 n_classes"
    #     assert np.all((p_in >= 0) & (p_in <= 1)), "p_in 必须在 [0,1]"
    #
    #     class_sizes = np.asarray(class_sizes, dtype=int).reshape(-1)
    #     assert len(class_sizes) == n_classes, "class_sizes 长度必须等于 n_classes"
    #     assert class_sizes.sum() == N, "class_sizes 之和必须等于 N"
    #
    #     if np.isscalar(q):
    #         use_q_matrix = False
    #         q_scalar = float(q)
    #         assert 0 <= q_scalar <= 1, "q 标量必须在 [0,1]"
    #     else:
    #         use_q_matrix = True
    #         q_mat = np.asarray(q, dtype=float)
    #         assert q_mat.shape == (n_classes, n_classes), "q 矩阵形状必须为 (k,k)"
    #         # 强制对称
    #         q_mat = 0.5 * (q_mat + q_mat.T)
    #         # 对角线不用（由 p_in 决定），可强制为 0 以避免误用
    #         np.fill_diagonal(q_mat, 0.0)
    #         assert np.all((q_mat >= 0) & (q_mat <= 1)), "q 矩阵元素必须在 [0,1]"
    #
    #     # 建立每个社区的区间边界
    #     boundaries = np.cumsum([0] + class_sizes.tolist())
    #
    #     # ---- 构造期望概率矩阵（未置乱）----
    #     prob_mat = np.zeros((N, N), dtype=float)
    #
    #     # 块内：使用各自的 p_in[i]
    #     for i in range(n_classes):
    #         s, e = boundaries[i], boundaries[i + 1]
    #         if e - s >= 2:
    #             prob_mat[s:e, s:e] = p_in[i]
    #         else:
    #             # 单节点块，块内不会产生边，跳过也可
    #             prob_mat[s:e, s:e] = 0.0
    #
    #     # 块间：q 标量或 q 矩阵
    #     for i in range(n_classes):
    #         si, ei = boundaries[i], boundaries[i + 1]
    #         for j in range(i + 1, n_classes):
    #             sj, ej = boundaries[j], boundaries[j + 1]
    #             pij = q_mat[i, j] if use_q_matrix else q_scalar
    #             if pij < 0 or pij > 1:
    #                 raise ValueError("块间概率不在 [0,1]")
    #             prob_mat[si:ei, sj:ej] = pij
    #             prob_mat[sj:ej, si:ei] = pij
    #
    #     # 去自环
    #     np.fill_diagonal(prob_mat, 0.0)
    #
    #     # ---- 上三角采样并对称化 ----
    #     if isinstance(rng, np.random.Generator):
    #         gen = rng
    #     else:
    #         gen = np.random.default_rng(rng)
    #
    #     iu, ju = np.triu_indices(N, k=1)
    #     W = np.zeros((N, N), dtype=dtype)
    #     W[iu, ju] = (gen.random(iu.size) < prob_mat[iu, ju]).astype(dtype)
    #     W = (W + W.T).astype(dtype)
    #
    #     # 原始（未置乱）标签：按块顺序 0..k-1
    #     labels = np.zeros(N, dtype=int)
    #     for i in range(n_classes):
    #         s, e = boundaries[i], boundaries[i + 1]
    #         labels[s:e] = i
    #
    #     # ---- 置乱（可选）----
    #     if shuffle:
    #         perm = torch.randperm(N).numpy()
    #         W_out = W[perm][:, perm]
    #         labels_out = labels[perm]
    #         prob_mat_out = prob_mat[perm][:, perm]
    #     else:
    #         W_out = W
    #         labels_out = labels
    #         prob_mat_out = prob_mat
    #
    #     # ---- 可选：返回置乱后期望矩阵的前 k 个特征向量 ----
    #     if return_eigs:
    #         # eigh 返回升序，这里取从大到小的前 k 个
    #         eigvals, eigvecs = np.linalg.eigh(prob_mat_out)
    #         idx = np.argsort(eigvals)[::-1][:n_classes]
    #         eigvecs_top = eigvecs[:, idx]
    #     else:
    #         eigvecs_top = None
    #
    #     return W_out, labels_out, eigvecs_top

    def create_dataset_random_otf(self, directory, mode='train', C=10, min_size=50):
        """
        Generate random imbalanced SBM graphs and save each to its own
        compressed ``.npz`` file.  There are no fixed p/q: parameters are
        re-drawn for every graph.

        The a/b ratio of each graph is drawn between the ratios that
        ``find_a_given_snr`` returns for SNR 0.5 and SNR 2 at degree constant
        ``C``.  The bounds are swapped if they come out in decreasing order,
        as in ``random_sample_otf_single``.

        Each ``graph_XXXX.npz`` stores the CSR components of the adjacency
        (``adj_data``, ``adj_indices``, ``adj_indptr``, ``adj_shape``) plus
        ``labels``, ``p``, ``q``, ``snr`` and ``class_sizes``.

        Args:
            directory: output directory (created if missing).
            mode: 'train', 'test' or 'val'; selects graph size and count.
            C: degree constant forwarded to the parameter sampler.
            min_size: minimum community size.

        Side effects:
            Sets ``self.data_train`` / ``self.data_test`` / ``self.data_val``
            (according to ``mode``) to ``directory``.  Graphs are not kept in
            memory.

        Raises:
            ValueError: for an unsupported ``mode``.  This is checked before
                anything is created on disk.
        """
        split_config = {
            'train': (self.N_train, self.num_examples_train),
            'test': (self.N_test, self.num_examples_test),
            'val': (self.N_val, self.num_examples_val),
        }
        if mode not in split_config:
            raise ValueError(f"Unsupported mode: {mode}")
        graph_size, num_graphs = split_config[mode]

        os.makedirs(directory, exist_ok=True)

        # The a/b ratio bounds depend only on (n_classes, C): compute them once
        # instead of once per graph.
        a_low, b_low = find_a_given_snr(0.5, self.n_classes, C)
        a_high, b_high = find_a_given_snr(2, self.n_classes, C)
        lower_bound = a_low / b_low
        upper_bound = a_high / b_high
        if lower_bound > upper_bound:
            lower_bound, upper_bound = upper_bound, lower_bound

        for i in range(num_graphs):
            # Draw SBM parameters and community sizes for this graph.
            p, q, class_sizes, snr = self.random_imbalanced_SBM_generator_balanced_sampling(
                N=graph_size,
                n_classes=self.n_classes,
                C=C,
                alpha_range=(lower_bound, upper_bound),
                min_size=min_size
            )

            W_dense, labels, _ = self.imbalanced_SBM_multiclass(
                p, q, graph_size, self.n_classes, class_sizes
            )

            # Store the adjacency in CSR form.
            W_sparse = csr_matrix(W_dense)
            graph_path = os.path.join(directory, f"graph_{i:04d}.npz")
            np.savez_compressed(
                graph_path,
                adj_data=W_sparse.data,
                adj_indices=W_sparse.indices,
                adj_indptr=W_sparse.indptr,
                adj_shape=W_sparse.shape,
                labels=labels,
                p=p,
                q=q,
                snr=snr,
                class_sizes=np.array(class_sizes)
            )

        print(f" {mode} 数据集已保存到目录: {directory}")

        # Keep only the directory; graphs are loaded lazily elsewhere.
        setattr(self, f"data_{mode}", directory)

    # def create_dataset_random_otf(
    #         self, directory, mode='train', C=10, min_size=50, *,
    #         C_list=None, C_range=None, seed=None
    # ):
    #     """
    #     生成随机 SBM_multiclass 图并逐图保存为稀疏格式 .npz 文件。
    #     - 若传入 C_list，则每张图随机从列表里抽一个 C；
    #     - 若传入 C_range=(low, high)，则从区间内随机采样一个 C；
    #     - 否则使用固定的 C。
    #     """
    #     if not os.path.exists(directory):
    #         os.makedirs(directory)
    #
    #     if mode == 'train':
    #         graph_size = self.N_train
    #         num_graphs = self.num_examples_train
    #     elif mode == 'test':
    #         graph_size = self.N_test
    #         num_graphs = self.num_examples_test
    #     elif mode == 'val':
    #         graph_size = self.N_val
    #         num_graphs = self.num_examples_val
    #     else:
    #         raise ValueError(f"Unsupported mode: {mode}")
    #
    #     rng = np.random.default_rng(seed)
    #
    #     for i in range(num_graphs):
    #         # === 新增：随机选择 C ===
    #         if C_list is not None:
    #             C_choice = float(rng.choice(C_list))
    #         elif C_range is not None:
    #             C_choice = float(rng.uniform(*C_range))  # 从区间随机选
    #         else:
    #             C_choice = C
    #
    #         # Step 1: SNR 控制边密度（保持你原逻辑）
    #         a_low, b_low = find_a_given_snr(0.1, self.n_classes, C_choice)
    #         a_high, b_high = find_a_given_snr(1.5, self.n_classes, C_choice)
    #         lower_bound = a_low / b_low
    #         upper_bound = a_high / b_high
    #
    #         # Step 2: 生成 SBM 参数和图
    #         p, q, class_sizes, snr = self.random_imbalanced_SBM_generator_balanced_sampling(
    #             N=graph_size,
    #             n_classes=self.n_classes,
    #             C=C_choice,
    #             alpha_range=(lower_bound, upper_bound),
    #             min_size=min_size
    #         )
    #
    #         # print(p, q, class_sizes, snr, C_choice)
    #
    #         W_dense, labels, eigvecs_top = self.imbalanced_SBM_multiclass(
    #             p, q, graph_size, self.n_classes, class_sizes
    #         )
    #
    #         # Step 3: 稀疏化邻接矩阵
    #         W_sparse = csr_matrix(W_dense)
    #
    #         # Step 4: 保存为稀疏 .npz 文件
    #         graph_path = os.path.join(directory, f"graph_{i:04d}.npz")
    #         np.savez_compressed(
    #             graph_path,
    #             adj_data=W_sparse.data,
    #             adj_indices=W_sparse.indices,
    #             adj_indptr=W_sparse.indptr,
    #             adj_shape=W_sparse.shape,
    #             labels=labels,
    #             p=p,
    #             q=q,
    #             snr=snr,
    #             class_sizes=np.array(class_sizes),
    #             C=np.float32(C_choice)
    #         )
    #
    #     print(f"{mode} 数据集已保存到目录: {directory}")
    #
    #     if mode == 'train':
    #         self.data_train = directory
    #     elif mode == 'test':
    #         self.data_test = directory
    #     elif mode == 'val':
    #         self.data_val = directory

    # def create_dataset_random_otf(
    #         self, directory, mode='train', C=10, min_size=50, *,
    #         C_list=None, C_range=None, seed=42
    # ):
    #     """
    #     方案A（root-gap 控难度）版本：p=a*log(n)/n, q=b*log(n)/n。
    #     - 平均度常数 C 在 [5,20] 内随机采样（若传入 C_list/C_range，则以传参为准）；
    #     - 难度用 g = (sqrt(a) - sqrt(b))^2 控制（与对数平均度 exact-recovery 阈值一致）；
    #     - 训练：阈值附近占大头；测试：范围略放宽；验证（若用）可等距网格；
    #     - 依赖：self.imbalanced_SBM_multiclass(p,q,N,k,class_sizes) -> (W_dense, labels, _)
    #     参考阈值：ABH'14; 教材/讲义同样给出两类 g>2（或一般化 CH 门槛）。
    #     """
    #
    #     if not os.path.exists(directory):
    #         os.makedirs(directory)
    #
    #     # --- 规模 ---
    #     if mode == 'train':
    #         graph_size = self.N_train
    #         num_graphs = self.num_examples_train
    #     elif mode == 'test':
    #         graph_size = self.N_test
    #         num_graphs = self.num_examples_test
    #     elif mode == 'val':
    #         graph_size = self.N_val
    #         num_graphs = self.num_examples_val
    #     else:
    #         raise ValueError(f"Unsupported mode: {mode}")
    #
    #     # --- 随机数生成器 ---
    #     if seed is None:
    #         # 如果没传，就用系统随机
    #         rng = np.random.default_rng(None)
    #     else:
    #         # 给不同 mode 加偏移，避免生成相同的图
    #         if mode == 'train':
    #             rng = np.random.default_rng(seed + 0)
    #         elif mode == 'val':
    #             rng = np.random.default_rng(seed + 1)
    #         elif mode == 'test':
    #             rng = np.random.default_rng(seed + 2)
    #         else:
    #             raise ValueError(f"Unsupported mode: {mode}")
    #
    #     k = int(self.n_classes)
    #     n = int(graph_size)
    #     logn = float(np.log(n))
    #
    #     # === (C,k,g) -> (a,b) 正确封闭解 ===
    #     # 由 x=√a, y=√b, 约束：x^2+(k-1)y^2=C 与 (x-y)^2=g
    #     # 推得：
    #     #   √b = ( √(kC - (k-1)g) - √g ) / k
    #     #   √a = ( √(kC - (k-1)g) + (k-1)√g ) / k
    #     #   a=(√a)^2, b=(√b)^2
    #     def ab_from_rootgap(C_choice: float, k: int, g: float):
    #         if not (0.0 < g < C_choice):
    #             raise ValueError("root-gap g must satisfy 0 < g < C")
    #         import math
    #         G = math.sqrt(g)
    #         X = math.sqrt(k * C_choice - (k - 1) * g)  # 实数：因 g < C
    #         sqrt_b = (X - G) / k
    #         sqrt_a = (X + (k - 1) * G) / k
    #         a = max(0.0, sqrt_a * sqrt_a)
    #         b = max(0.0, sqrt_b * sqrt_b)
    #         return a, b
    #
    #     # === root-gap 采样策略 ===
    #     def sample_rootgap_train(m):
    #         # 主战区（阈值附近）+ 两侧少量；二类阈值~2，多类量级~k
    #         bins = [(0.8, 1.4, 0.55), (1.4, 2.5, 0.25), (2.5, 3.5, 0.10), (0.4, 0.8, 0.10)]
    #         counts = rng.multinomial(m, [w for _, _, w in bins])
    #         parts = [rng.uniform(lo, hi, size=c) for (lo, hi, _), c in zip(bins, counts) if c > 0]
    #         return np.concatenate(parts) if parts else np.array([], dtype=float)
    #
    #     def sample_rootgap_test(m):
    #         bins = [(0.8, 1.6, 0.60), (1.6, 2.8, 0.25), (2.8, 3.6, 0.10), (0.4, 0.8, 0.05)]
    #         counts = rng.multinomial(m, [w for _, _, w in bins])
    #         parts = [rng.uniform(lo, hi, size=c) for (lo, hi, _), c in zip(bins, counts) if c > 0]
    #         return np.concatenate(parts) if parts else np.array([], dtype=float)
    #
    #     def sample_rootgap_val(m):
    #         return np.linspace(0.9, 3, num=max(m, 2))[:m]
    #
    #     if mode == 'train':
    #         g_list = sample_rootgap_train(num_graphs)
    #     elif mode == 'test':
    #         g_list = sample_rootgap_test(num_graphs)
    #     else:
    #         g_list = sample_rootgap_val(num_graphs)
    #
    #     # === 平均度 C 采样（与难度解耦）；默认 Uniform[5,20] ===
    #     def sample_C():
    #         if C_list is not None:
    #             return float(rng.choice(C_list))
    #         elif C_range is not None:
    #             return float(rng.uniform(*C_range))
    #         else:
    #             return float(rng.uniform(5.0, 20.0))  # 你的要求
    #
    #     # === 类别大小（Dirichlet 不平衡，带最小规模） ===
    #     def sample_class_sizes(n_nodes, n_classes, alpha=1.0, min_sz=5):
    #         prop = rng.dirichlet(np.ones(n_classes) * alpha)
    #         prop = np.maximum(prop, (min_sz + 1e-9) / n_nodes)
    #         prop = prop / prop.sum()
    #         sizes = np.floor(prop * n_nodes).astype(int)
    #         diff = n_nodes - sizes.sum()
    #         if diff > 0:
    #             idx = np.argsort(-prop)[:diff]
    #             sizes[idx] += 1
    #         return sizes.tolist()
    #
    #     for i, g_target in enumerate(g_list):
    #         # 1) 采 C
    #         C_choice = sample_C()
    #
    #         # 2) (C,k,g) -> (a,b)；轻微抖动 a/b 比防刚性（可选）
    #         a, b = ab_from_rootgap(C_choice, k, float(g_target))
    #         r = a / b if b > 0 else 1.0
    #         r *= rng.uniform(0.95, 1.05)  # ±5% 抖动
    #         b = C_choice / (r + k - 1.0)
    #         a = r * b
    #
    #         # 3) p,q （对数平均度）
    #         p = float(a * logn / n)
    #         q = float(b * logn / n)
    #
    #         # 4) 类别大小（混合不平衡强度；你也可固定 alpha=1.0）
    #         alpha_choice = rng.choice([500, 1], p=[0.9,0.1])
    #         class_sizes = sample_class_sizes(n, k, alpha=float(alpha_choice), min_sz=min_size)
    #
    #         # 5) 采样图
    #         W_dense, labels, _ = self.imbalanced_SBM_multiclass(p, q, n, k, class_sizes)
    #
    #         # 6) 记录“实现的”root-gap（用最终 a,b 回算）
    #         rootgap_sq_realized = (np.sqrt(max(a, 1e-12)) - np.sqrt(max(b, 1e-12))) ** 2
    #
    #         # 7) 保存
    #         W_sparse = csr_matrix(W_dense)
    #         graph_path = os.path.join(directory, f"graph_{i:04d}.npz")
    #         np.savez_compressed(
    #             graph_path,
    #             adj_data=W_sparse.data,
    #             adj_indices=W_sparse.indices,
    #             adj_indptr=W_sparse.indptr,
    #             adj_shape=W_sparse.shape,
    #             labels=np.asarray(labels, dtype=np.int32),
    #             p=np.float32(p), q=np.float32(q),
    #             a=np.float32(a), b=np.float32(b),
    #             C=np.float32(C_choice),
    #             target_rootgap_sq=np.float32(g_target),
    #             rootgap_sq=np.float32(rootgap_sq_realized),
    #             class_sizes=np.asarray(class_sizes, dtype=np.int32),
    #             n=np.int32(n), k=np.int32(k),
    #         )
    #         # print(p,q,C_choice,g_target,class_sizes)
    #
    #     print(f"{mode} 数据集已保存到目录: {directory}")
    #
    #     if mode == 'train':
    #         self.data_train = directory
    #     elif mode == 'test':
    #         self.data_test = directory
    #     elif mode == 'val':
    #         self.data_val = directory

    # def create_dataset_random_otf_political_blog(self,directory):
    # # —— 训练集（≈2000）
    # def prepare_data(self):
    #     def get_npz_dataset(path, mode):
    #         if not os.path.exists(path):
    #             os.makedirs(path)
    #             print(f"[创建数据集] {mode} 数据目录不存在，已新建：{path}")
    #
    #         npz_files = sorted([f for f in os.listdir(path) if f.endswith(".npz")])
    #         if not npz_files:
    #             print(f"[创建数据集] {mode} 数据未找到，开始生成...")
    #             self.create_dataset_grid(path, mode=mode, C = 10, min_size=50)
    #             npz_files = sorted([f for f in os.listdir(path) if f.endswith(".npz")])
    #         else:
    #             print(f"[读取数据] {mode} 集已存在，共 {len(npz_files)} 张图：{path}")
    #
    #         # 返回路径列表
    #         return [os.path.join(path, f) for f in npz_files]
    #
    #     train_dir = f"{self.generative_model}_nc{self.n_classes}_rand_gstr{self.N_train}_numtr{self.num_examples_train}"
    #     test_dir = f"{self.generative_model}_nc{self.n_classes}_rand_gste{self.N_test}_numte{self.num_examples_test}"
    #     val_dir = f"{self.generative_model}_nc{self.n_classes}_rand_val{self.N_val}_numval{self.num_examples_val}"
    #
    #     train_path = os.path.join(self.path_dataset, train_dir)
    #     test_path = os.path.join(self.path_dataset, test_dir)
    #     val_path = os.path.join(self.path_dataset, val_dir)
    #
    #     self.data_train = get_npz_dataset(train_path, 'train')
    #     self.data_test = get_npz_dataset(test_path, 'test')
    #     self.data_val = get_npz_dataset(val_path, 'val')

    def prepare_data(self):
        """
        Locate the train / val / test datasets, generating any that are missing.

        For each split, a directory under ``self.path_dataset`` is named from
        the model name, class count, graph size and example count.  If it
        contains no ``.npz`` files, ``self.create_dataset_grid`` is called with
        that split's module-level grid (``snr_*``, ``gamma_*``, ``C_*``,
        ``per_cell_*``) and a split-specific base seed.  ``self.data_train``,
        ``self.data_val`` and ``self.data_test`` are then set to the sorted
        lists of ``.npz`` file paths.
        """
        def list_npz(folder):
            return sorted(name for name in os.listdir(folder) if name.endswith(".npz"))

        def get_npz_dataset(path, mode, *, snr_grid, gamma_grid, C_grid, per_cell,
                            min_size=50, base_seed=0):
            if not os.path.exists(path):
                os.makedirs(path)
                print(f"[创建数据集] {mode} 数据目录不存在，已新建：{path}")

            npz_files = list_npz(path)
            if npz_files:
                print(f"[读取数据] {mode} 集已存在，共 {len(npz_files)} 张图：{path}")
            else:
                print(f"[创建数据集] {mode} 数据未找到，开始生成...")
                self.create_dataset_grid(
                    path, mode=mode,
                    snr_grid=snr_grid, gamma_grid=gamma_grid, C_grid=C_grid,
                    per_cell=per_cell, min_size=min_size, base_seed=base_seed,
                )
                npz_files = list_npz(path)
            return [os.path.join(path, name) for name in npz_files]

        # Directory names share a common prefix.
        prefix = f"{self.generative_model}_nc{self.n_classes}_rand"
        train_path = os.path.join(
            self.path_dataset, f"{prefix}_gstr{self.N_train}_numtr{self.num_examples_train}")
        test_path = os.path.join(
            self.path_dataset, f"{prefix}_gste{self.N_test}_numte{self.num_examples_test}")
        val_path = os.path.join(
            self.path_dataset, f"{prefix}_val{self.N_val}_numval{self.num_examples_val}")

        # Splits are prepared in the order train -> val -> test.
        self.data_train = get_npz_dataset(
            train_path, 'train', snr_grid=snr_train, gamma_grid=gamma_train,
            C_grid=C_train, per_cell=per_cell_tr, min_size=50, base_seed=123,
        )
        self.data_val = get_npz_dataset(
            val_path, 'val', snr_grid=snr_val, gamma_grid=gamma_val,
            C_grid=C_val, per_cell=per_cell_v, min_size=50, base_seed=2025,
        )
        self.data_test = get_npz_dataset(
            test_path, 'test', snr_grid=snr_test, gamma_grid=gamma_test,
            C_grid=C_test, per_cell=per_cell_te, min_size=50, base_seed=31415,
        )


    def sample_single(self, i, is_training=True):
        if is_training:
            dataset = self.data_train
        else:
            dataset = self.data_test
        example = dataset[i]
        if (self.generative_model == 'SBM_multiclass'):
            W_np = example['W']
            labels = np.expand_dims(example['labels'], 0)
            labels_var = torch.from_numpy(labels)
            if is_training:
                labels_var.requires_grad = True
            return W_np, labels_var


    def sample_otf_single(self, is_training=True, cuda=True):
        """
        Sample a fresh graph on the fly with the configured ``p_SBM`` / ``q_SBM``.

        Args:
            is_training: use ``self.N_train`` nodes if True, else ``self.N_test``.
            cuda: unused; kept for backward compatibility.

        Returns:
            W: numpy adjacency with a leading axis, shape (1, N, N).
            labels: torch tensor of labels, shape (1, N).
            eigvecs_top: eigenvectors returned by ``SBM_multiclass``, or None
                for the two-block 'SBM' model, which does not compute them.

        Raises:
            ValueError: for an unsupported ``self.generative_model``.
        """
        N = self.N_train if is_training else self.N_test

        # The plain 'SBM' generator returns no eigenvectors; without this
        # default the return statement raised UnboundLocalError.
        eigvecs_top = None
        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
        elif self.generative_model == 'SBM_multiclass':
            W, labels, eigvecs_top = self.SBM_multiclass(self.p_SBM, self.q_SBM, N, self.n_classes)
        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        labels = torch.from_numpy(np.expand_dims(labels, 0))
        W = np.expand_dims(W, 0)  # left as numpy; no requires_grad

        return W, labels, eigvecs_top

    def imbalanced_sample_otf_single(self, class_sizes, is_training=True, cuda=True):
        """
        Sample a fresh graph with given community sizes and the configured
        ``p_SBM`` / ``q_SBM``.

        Args:
            class_sizes: community sizes forwarded to
                ``imbalanced_SBM_multiclass``.  Ignored by the 'SBM' model.
            is_training: use ``self.N_train`` nodes if True, else ``self.N_test``.
            cuda: unused; kept for backward compatibility.

        Returns:
            W: numpy adjacency with a leading axis, shape (1, N, N).
            labels: torch tensor of labels, shape (1, N).
            eigvecs_top: eigenvectors from ``imbalanced_SBM_multiclass``, or
                None for the 'SBM' model, which does not compute them.

        Raises:
            ValueError: for an unsupported ``self.generative_model``.
        """
        N = self.N_train if is_training else self.N_test

        # Default for the 'SBM' branch; previously this was unbound there and
        # the return statement raised UnboundLocalError.
        eigvecs_top = None
        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
        elif self.generative_model == 'SBM_multiclass':
            W, labels, eigvecs_top = self.imbalanced_SBM_multiclass(
                self.p_SBM, self.q_SBM, N, self.n_classes, class_sizes)
        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        labels = torch.from_numpy(np.expand_dims(labels, 0))
        W = np.expand_dims(W, 0)  # left as numpy; no requires_grad

        return W, labels, eigvecs_top


    def random_sample_otf_single(self, C=10, is_training=True, cuda=True):
        """
        Sample one graph on the fly with randomly drawn SBM parameters.

        For 'SBM_multiclass':
          - the a/b ratio is drawn between the ratios that
            ``find_a_given_snr`` returns for SNR 0.1 and SNR 1 at degree
            constant ``C`` (swapped if they come out in decreasing order);
          - community sizes are random with at least 20 nodes each;
          - the graph is drawn with ``imbalanced_SBM_multiclass``.

        For 'SBM', the fixed ``p_SBM`` / ``q_SBM`` two-block model is used.

        Args:
            C: degree constant for the parameter sampler.
            is_training: use ``self.N_train`` nodes if True, else ``self.N_test``.
            cuda: unused; kept for backward compatibility.

        Returns:
            W: numpy adjacency with a leading axis, shape (1, N, N).
            labels: torch tensor of labels, shape (1, N).
            eigvecs_top, snr, class_sizes: as produced by the multiclass
                sampler.  All three are None for the 'SBM' model, which does
                not produce them.

        Raises:
            ValueError: for an unsupported ``self.generative_model``.
        """
        N = self.N_train if is_training else self.N_test

        # Defaults for the 'SBM' branch; previously these were unbound there
        # and the return statement raised UnboundLocalError.
        eigvecs_top = snr = class_sizes = None

        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)

        elif self.generative_model == 'SBM_multiclass':
            a_low, b_low = find_a_given_snr(0.1, self.n_classes, C)
            a_high, b_high = find_a_given_snr(1, self.n_classes, C)

            lower_bound = a_low / b_low
            upper_bound = a_high / b_high
            if lower_bound > upper_bound:
                lower_bound, upper_bound = upper_bound, lower_bound

            p, q, class_sizes, snr = self.random_imbalanced_SBM_generator_balanced_sampling(
                N=N,
                n_classes=self.n_classes,
                C=C,
                alpha_range=(lower_bound, upper_bound),
                min_size=20
            )
            W, labels, eigvecs_top = self.imbalanced_SBM_multiclass(
                p, q, N, self.n_classes, class_sizes)

        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        labels = torch.from_numpy(np.expand_dims(labels, 0))
        W = np.expand_dims(W, 0)  # left as numpy; no requires_grad

        return W, labels, eigvecs_top, snr, class_sizes


    def random_imbalanced_SBM_generator_balanced_sampling(self, N, n_classes, C, *,
                                        alpha_range=(1.3, 2.8),
                                        min_size=5):
        """
        随机生成 SBM 模型的参数，社区大小为随机比例但总和为 N。
        返回 p, q, class_sizes, a, b, snr。
        """
        assert N >= min_size * n_classes

        # Step 1: 随机生成 a > b，使得 a + (k - 1) * b = C
        alpha = np.random.uniform(*alpha_range)
        b = C / (alpha + (n_classes - 1))
        a = alpha * b

        # Step 2: 计算边连接概率
        logn = np.log(N)
        p = a * logn / N
        q = b * logn / N

        # ✅ Step 3: 使用 Dirichlet 生成 class_sizes
        remaining = N - min_size * n_classes
        probs = np.random.dirichlet(np.ones(n_classes))  # 总和为1的概率向量
        extras = np.random.multinomial(remaining, probs)
        class_sizes = [min_size + e for e in extras]

        # Step 4: 计算 SNR
        snr = (a - b) ** 2 / (n_classes * (a + (n_classes - 1) * b))

        return p, q, class_sizes, snr

    def _sample_class_sizes_dirichlet(self, N, n_classes, gamma, min_size, rng,
                                      gamma_jitter=0.5):
        """按 Dirichlet(gamma) 采样类比例 + min_size 下界，支持 gamma 抖动。"""
        assert N >= min_size * n_classes
        remaining = N - min_size * n_classes

        # 对 gamma 做整体抖动：γ' = γ * U(1-δ, 1+δ)
        if gamma_jitter and gamma_jitter > 0:
            mult = rng.uniform(1.0 - gamma_jitter, 1.0 + gamma_jitter)
            gamma_used = gamma * mult
        else:
            gamma_used = gamma

        probs = rng.dirichlet(np.full(n_classes, gamma_used, dtype=float))
        extras = rng.multinomial(remaining, probs)
        return [min_size + int(e) for e in extras]  # 也可以返回实际使用的 γ

    # def gen_one_sbm_by_targets(
    #         self, N, n_classes, C, target_snr, gamma, min_size=5, *, rng=None,
    #         heterophily=False, hetero_prob=None,
    #         # === 反刚性（抖动）相关参数 ===
    #         r_jitter=0.05,  # 对 r = a/b 进行 ±5% 抖动；设为 0 关闭
    #         keep_assortativity=True,  # True：保持同配/异配属性不变
    #         pq_jitter=None  # 可选：对最终 p、q 再做一次微抖动，比如 (0.02, 0.02)
    # ):
    #     """
    #     用目标 SNR + Dirichlet(gamma) 生成一张 SBM 图；支持“反刚性”抖动。
    #     返回: p, q, class_sizes, snr_real, a, b, gamma, is_hetero
    #     """
    #     rng = np.random.default_rng() if rng is None else rng
    #
    #     # 1) 基准 a,b（默认 a>b 为同配）
    #     a0, b0 = find_a_given_snr(target_snr, n_classes, C)  # 你已有
    #     r0 = a0 / b0 if b0 > 0 else 1.0
    #
    #     # 2) 是否异配（全局对调 a,b）
    #     if hetero_prob is not None:
    #         is_hetero = bool(rng.random() < float(hetero_prob))
    #     else:
    #         is_hetero = bool(heterophily)
    #
    #     if is_hetero:
    #         a0, b0 = b0, a0
    #         r0 = a0 / b0 if b0 > 0 else 1.0  # 重新计算 r0
    #
    #     # 3) 反刚性：对 r 做 ±r_jitter 抖动，并回代到 a,b 同时保持 a+(k-1)b=C
    #     if r_jitter and r_jitter > 0:
    #         mult = rng.uniform(1.0 - float(r_jitter), 1.0 + float(r_jitter))
    #         r = r0 * mult
    #     else:
    #         r = r0
    #
    #     # 回代（保持 C 不变）： b = C / (r + k - 1), a = r * b
    #     b = C / (r + n_classes - 1.0)
    #     a = r * b
    #
    #     # 4) 若需要保持同配/异配属性，则做一次保护（防止抖动把 a 与 b 的大小关系翻转）
    #     if keep_assortativity:
    #         if is_hetero:
    #             # 异配应满足 a<b；若被抖动破坏则交换
    #             if not (a < b):
    #                 a, b = min(a, b), max(a, b)
    #         else:
    #             # 同配应满足 a>b；若被抖动破坏则交换
    #             if not (a > b):
    #                 a, b = max(a, b), min(a, b)
    #
    #     # 5) p, q（对数平均度）
    #     logn = np.log(N)
    #     p = float(a * logn / N)
    #     q = float(b * logn / N)
    #
    #     # 可选：对 p、q 再做一次很小的独立抖动（“反刚性加强版”）
    #     if pq_jitter is not None:
    #         pj, qj = pq_jitter
    #         if pj and pj > 0:
    #             p *= rng.uniform(1.0 - float(pj), 1.0 + float(pj))
    #         if qj and qj > 0:
    #             q *= rng.uniform(1.0 - float(qj), 1.0 + float(qj))
    #
    #     # 6) 类大小（Dirichlet 控不平衡）
    #     class_sizes = self._sample_class_sizes_dirichlet(N, n_classes, gamma, min_size, rng)
    #
    #     # 7) 实际 SNR（记录一下，方便画图/筛选）
    #     snr_real = (a - b) ** 2 / (n_classes * (a + (n_classes - 1) * b))
    #
    #     return p, q, class_sizes, snr_real, a, b, gamma

    def gen_one_sbm_by_targets(
            self, N, n_classes, C, target_snr, gamma, min_size=5, *, rng=None,
            heterophily=False, hetero_prob=None,
            # === anti-rigidity (jitter) parameters ===
            r_jitter=0.05,  # relative +/- jitter on r = a / b (5% by default); 0 disables it
            keep_assortativity=True,  # preserve the homophilous (a > b) / heterophilous (a < b) ordering
            pq_jitter=None,  # optional (pj, qj): small independent relative jitter on the final p, q
            # === mild jitter on C ===
            C_jitter=0.1,  # ON by default (10%); set 0 to disable; 0.03~0.10 recommended
            C_jitter_mode='relative',  # 'relative' or 'absolute'
            b_floor=1e-6  # lower bound that keeps b strictly positive
    ):
        """
        Generate the parameters of one SBM graph from a target SNR and a Dirichlet(gamma) imbalance.

        Pipeline:
            1. Jitter C.
            2. Solve (a, b) for ``target_snr`` with ``find_a_given_snr`` using ``C_used``.
            3. Optionally swap a and b (heterophily).
            4. Jitter r = a / b while keeping ``a + (k - 1) * b = C_used``.
            5. Optionally restore the assortativity ordering, then floor b.
            6. Compute ``p = a * log(N) / N`` and ``q = b * log(N) / N``, with optional p/q jitter.
            7. Sample the class sizes.

        Args:
            N: number of nodes.
            n_classes: number of classes k.
            C: nominal value of ``a + (k - 1) * b``.
            target_snr: SNR passed to ``find_a_given_snr``.
            gamma: Dirichlet concentration for the class sizes (forwarded to
                ``_sample_class_sizes_dirichlet``, which may jitter it internally).
            min_size: minimum class size.
            rng: ``numpy.random.Generator``; a fresh unseeded one is created when ``None``.
            heterophily: swap a and b when ``hetero_prob`` is ``None``.
            hetero_prob: if given, swap a and b with this probability instead.
            r_jitter, keep_assortativity, pq_jitter, C_jitter, C_jitter_mode, b_floor:
                see the inline comments in the signature.

        Returns:
            ``(p, q, class_sizes, snr_real, a, b, gamma, is_hetero, C_used)``.
            ``gamma`` is the nominal input value, not the jittered one used for
            sampling. ``snr_real`` is computed from (a, b) before any ``pq_jitter``.

        Raises:
            ValueError: if ``C_jitter_mode`` is neither 'relative' nor 'absolute'.
        """
        # Previously an unknown mode silently disabled the C jitter; fail loudly instead.
        if C_jitter_mode not in ('relative', 'absolute'):
            raise ValueError(f"Unsupported C_jitter_mode: {C_jitter_mode!r}")
        rng = np.random.default_rng() if rng is None else rng

        # === 0) mild jitter on C (if enabled) ===
        C_used = float(C)
        if C_jitter and C_jitter > 0:
            if C_jitter_mode == 'relative':
                C_used *= rng.uniform(1.0 - float(C_jitter), 1.0 + float(C_jitter))
            else:  # 'absolute'
                C_used += rng.uniform(-float(C_jitter), float(C_jitter))
            # Guard: C must stay strictly positive.
            C_used = max(C_used, 1e-6)

        # === 1) baseline (a, b) for the target SNR, solved with C_used ===
        a0, b0 = find_a_given_snr(target_snr, n_classes, C_used)

        # === 2) homophily / heterophily choice (global swap of a and b) ===
        if hetero_prob is not None:
            is_hetero = bool(rng.random() < float(hetero_prob))
        else:
            is_hetero = bool(heterophily)
        if is_hetero:
            a0, b0 = b0, a0

        # === 3) jitter r = a / b and map back keeping a + (k - 1) * b = C_used ===
        # The jitter factor is drawn whenever r_jitter is enabled, so the RNG stream
        # (and therefore the class sizes sampled later) does not depend on b0.
        r_mult = 1.0
        if r_jitter and r_jitter > 0:
            r_mult = rng.uniform(1.0 - float(r_jitter), 1.0 + float(r_jitter))
        if b0 > 0:
            # b = C' / (r + k - 1), a = r * b
            r = (a0 / b0) * r_mult
            b = C_used / (r + n_classes - 1.0)
            a = r * b
        else:
            # r = a / b is undefined for b0 <= 0. The former fallback r = 1 forced
            # a == b == C_used / k, i.e. SNR ~ 0 whatever target_snr was. Keep the
            # solver's pair instead and let the b_floor step below lift b off zero.
            a, b = float(a0), float(b0)

        # === 4) protect the assortativity ordering against jitter flipping it ===
        if keep_assortativity:
            if is_hetero:
                # heterophily requires a < b
                if not (a < b):
                    a, b = min(a, b), max(a, b)
            else:
                # homophily requires a > b
                if not (a > b):
                    a, b = max(a, b), min(a, b)

        # === 4.5) safety floor on b, avoiding a vanishing / unstable b ===
        if b < b_floor:
            b = b_floor
            a = C_used - (n_classes - 1) * b
            # This can break the homophily/heterophily ordering; if that matters,
            # tune b_floor or reduce C_jitter / r_jitter.

        # === 5) edge probabilities in the sparse log(N) / N regime ===
        logn = np.log(N)
        p = float(a * logn / N)
        q = float(b * logn / N)
        # NOTE(review): p and q are not clipped to [0, 1]. For example, p > 1 whenever
        # a * log(N) > N, which happens for large C with small N. Confirm that
        # imbalanced_SBM_multiclass tolerates this.

        # Optional: tiny independent jitter on p and q (extra anti-rigidity)
        if pq_jitter is not None:
            pj, qj = pq_jitter
            if pj and pj > 0:
                p *= rng.uniform(1.0 - float(pj), 1.0 + float(pj))
            if qj and qj > 0:
                q *= rng.uniform(1.0 - float(qj), 1.0 + float(qj))

        # === 6) class sizes (Dirichlet-controlled imbalance) ===
        class_sizes = self._sample_class_sizes_dirichlet(N, n_classes, gamma, min_size, rng)

        # === 7) realised SNR (recorded for plotting / filtering) ===
        snr_real = (a - b) ** 2 / (n_classes * (a + (n_classes - 1) * b))

        return p, q, class_sizes, snr_real, a, b, gamma, is_hetero, C_used

    def create_dataset_grid(self, directory, mode='train', *,
                            snr_grid=(0.6, 0.9, 1.1, 1.3, 1.6, 2.0, 2.5, 3.0),
                            gamma_grid=(0.15, 0.3, 0.6, 1.0, 2.0),
                            C_grid=(10.0,),  # e.g. (6, 10, 14) for several densities
                            per_cell=20,
                            min_size=5,
                            base_seed=0):
        """
        Generate graphs on the Cartesian grid (SNR x gamma x C), ``per_cell`` graphs per grid point.

        Every graph is saved as a compressed ``.npz`` in ``directory``. The file name
        encodes the split, a running index and the grid coordinates. The archive holds:
            - the CSR adjacency (``adj_data``, ``adj_indices``, ``adj_indptr``, ``adj_shape``);
            - the labels;
            - the generation metadata, so results can be reproduced and analysed per cell.

        Args:
            directory: output directory (created if missing).
            mode: 'train', 'val' or 'test'. Selects N (``self.N_train``, ``self.N_val``
                or ``self.N_test``) and stores ``directory`` in ``self.data_<mode>``.
            snr_grid: grid values for the target SNR.
            gamma_grid: grid values for the Dirichlet gamma.
            C_grid: grid values for C.
            per_cell: number of graphs generated per grid point.
            min_size: minimum class size.
            base_seed: offset added to every per-cell RNG seed.

        Raises:
            ValueError: for an unsupported ``mode``, raised before the directory is created.
        """
        # Validate the mode first so a bad call does not leave an empty directory behind.
        if mode == 'train':
            N = self.N_train
        elif mode == 'val':
            N = self.N_val
        elif mode == 'test':
            N = self.N_test
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        os.makedirs(directory, exist_ok=True)
        setattr(self, f"data_{mode}", directory)  # self.data_train / data_val / data_test

        num_graphs_expected = len(snr_grid) * len(gamma_grid) * len(C_grid) * per_cell

        idx = 0
        for c_idx, C in enumerate(C_grid):
            for s_idx, snr_target in enumerate(snr_grid):
                for g_idx, gamma in enumerate(gamma_grid):
                    # One RNG per grid cell for reproducibility.
                    # NOTE: seeds are distinct only while len(gamma_grid) <= 100 and
                    # len(snr_grid) <= 1000. base_seed values differing by a multiple of
                    # 100 can reuse other cells' seeds. The formula is kept unchanged so
                    # existing datasets stay reproducible.
                    cell_seed = base_seed + (c_idx * 10_000_000
                                             + s_idx * 10_000
                                             + g_idx * 100)
                    rng = np.random.default_rng(cell_seed)

                    for rep in range(per_cell):
                        # The returned gamma is just the nominal input. It is discarded
                        # rather than unpacked into (and shadowing) the loop variable.
                        (p, q, class_sizes, snr_real, a, b,
                         _gamma_nominal, is_hetero, C_used) = self.gen_one_sbm_by_targets(
                            N=N, n_classes=self.n_classes, C=C,
                            target_snr=snr_target, gamma=gamma,
                            min_size=min_size, rng=rng
                        )

                        W_dense, labels, eigvecs_top = self.imbalanced_SBM_multiclass(
                            p, q, N, self.n_classes, class_sizes
                        )
                        W_sparse = csr_matrix(W_dense)

                        fname = (f"{mode}_i{idx:05d}"
                                 f"__C{C:.2f}__snr{snr_target:.3f}"
                                 f"__g{gamma:.3f}__rep{rep:02d}.npz")
                        path = os.path.join(directory, fname)

                        np.savez_compressed(
                            path,
                            adj_data=W_sparse.data,
                            adj_indices=W_sparse.indices,
                            adj_indptr=W_sparse.indptr,
                            adj_shape=W_sparse.shape,
                            labels=labels,
                            p=p, q=q,
                            a=a, b=b,
                            C=C,  # nominal grid value (also encoded in the file name)
                            C_used=C_used,  # jittered C that (a, b) were solved with
                            is_hetero=is_hetero,  # whether a and b were swapped (heterophily)
                            snr_target=snr_target,
                            snr_real=snr_real,
                            gamma=gamma,
                            class_sizes=np.array(class_sizes, dtype=np.int32),
                            eigvecs_top=eigvecs_top  # drop this if the files get too large
                        )
                        idx += 1

        print(f"[{mode}] 网格数据完成: 共 {idx} 张（期望 {num_graphs_expected}）。目录: {directory}")

    def copy(self):
        return copy.deepcopy(self)